# Copyright (c) 2025 Apple Inc. Licensed under MIT License.

from typing import TypedDict, Unpack


class EmbeddingAtlasOptions(TypedDict, total=False):
    """
    row_id:
        The column name for row id (if not specified, a row id column will be added).

    x:
        The column name for X axis in the embedding.

    y:
        The column name for Y axis in the embedding.

    text:
        The column name for the textual data.

    neighbors:
        The column name containing precomputed K-nearest neighbors for each point.
        Each value in the column should be a dictionary with the format:
        ``{ "ids": [id1, id2, ...], "distances": [distance1, distance2, ...] }``.

        - ``"ids"`` should be an array of row ids of the neighbors
          (if ``row_id`` is specified, match the value in row_id, otherwise use the row index),
          sorted by distance.
        - ``"distances"`` should contain the corresponding distances to each neighbor.

    labels:
        Labels for the embedding view. Set to string ``"automatic"`` to generate labels automatically, or ``"disabled"`` to disable auto labels.
        Automatic labels are generated by clustering the 2D density distribution and selecting
        representative keywords using TF-IDF ranking.
        You can also pass in a list of labels. Each label must contain ``x`` and ``y`` coordinates
        and ``text`` for the label content. Optionally, you may specify an integer ``level`` to roughly
        control the zoom level where the label appears, and `priority` for the label's priority.
        Higher priority labels have a better chance to appear when multiple labels overlap.

    stop_words:
        Stop words for automatic label generation.

    point_size:
        Override the default point size for the embedding view.

    show_table:
        Whether to display the data table when the widget opens.

    show_charts:
        Whether to display charts when the widget opens.

    show_embedding:
        Whether to display the embedding view when the widget opens.
    """

    table: str | None
    row_id: str | None
    x: str | None
    y: str | None
    text: str | None
    neighbors: str | None

    point_size: float | None

    labels: list[dict] | None
    stop_words: list[str] | None

    show_table: bool | None
    show_charts: bool | None
    show_embedding: bool | None


def make_embedding_atlas_props(**options: Unpack[EmbeddingAtlasOptions]) -> dict:
    """
    Convert the input options to props of the EmbeddingAtlas view.

    Only options whose value is not ``None`` are written; nested props
    (e.g. ``data.table``) are created on demand. ``initialState.version``
    is always set to ``"0.0.0"``.

    Returns:
        A (possibly nested) dict of view props.

    Raises:
        ValueError: If an option name is not a key of ``EmbeddingAtlasOptions``,
            or if exactly one of ``x`` and ``y`` is given.
    """
    # Validate keys in options
    allowed_options = (
        EmbeddingAtlasOptions.__optional_keys__
        | EmbeddingAtlasOptions.__required_keys__
    )
    invalid_options = options.keys() - allowed_options

    if invalid_options:
        # Sort both lists: set iteration order is arbitrary, which would make the
        # message differ between runs.
        raise ValueError(
            f"The following options are not allowed for the Embedding Atlas widget: {', '.join(sorted(invalid_options))}. Allowed options are {', '.join(sorted(allowed_options))}"
        )

    x = options.get("x")
    y = options.get("y")
    # The projection needs both axes; previously a lone x or y was silently dropped.
    if (x is None) != (y is None):
        raise ValueError(
            "Options 'x' and 'y' must be specified together for the Embedding Atlas widget."
        )

    props: dict = {}

    def set_prop(key: str, value):
        """Set the prop with key to value, only if value is not None. Key can be a dot-separated path for nested properties"""
        if value is None:
            return
        *parents, leaf = key.split(".")
        d = props
        for part in parents:
            d = d.setdefault(part, {})
        d[leaf] = value

    # Data
    set_prop("data.table", options.get("table"))
    set_prop("data.id", options.get("row_id"))
    if x is not None and y is not None:
        set_prop("data.projection", {"x": x, "y": y})
    set_prop("data.text", options.get("text"))
    set_prop("data.neighbors", options.get("neighbors"))

    # Embedding View
    set_prop("embeddingViewConfig.pointSize", options.get("point_size"))
    set_prop("embeddingViewLabels", options.get("labels"))
    set_prop("embeddingViewConfig.autoLabelStopWords", options.get("stop_words"))

    # Layout
    set_prop("initialState.layoutStates.list.showTable", options.get("show_table"))
    set_prop("initialState.layoutStates.list.showCharts", options.get("show_charts"))
    set_prop(
        "initialState.layoutStates.list.showEmbedding", options.get("show_embedding")
    )

    set_prop("initialState.version", "0.0.0")

    return props
